import random
import numpy as np
from torchvision import datasets,transforms

def cifar_global(dataset,frac):
    """Sample a random ``frac`` fraction of the dataset's indices, without replacement.

    Args:
        dataset: any sized container; only ``len(dataset)`` is used.
        frac: fraction of indices to draw; the count is ``int(len(dataset) * frac)``.

    Returns:
        ``[]`` when the count truncates to zero, otherwise the numpy array
        returned by ``np.random.choice`` holding the sampled indices.
    """
    sample_size = int(len(dataset) * frac)
    candidates = list(range(len(dataset)))
    if not sample_size:
        return []
    return np.random.choice(candidates, sample_size, replace=False)

def mnist_iid(dataset, num_users, dict_global=None):
    """Split ``dataset`` IID across ``num_users`` users via ``iid``.

    Args:
        dataset: sized dataset whose indices are partitioned.
        num_users: number of users to create.
        dict_global: optional indices to hold out from every user. Defaults
            to none. Previously this wrapper called ``iid`` without this
            required argument and always raised ``TypeError``.

    Returns:
        Whatever ``iid`` returns: a dict of user id -> list of indices.
    """
    if dict_global is None:
        dict_global = []
    return iid(dataset, num_users, dict_global)


def mnist_noniid(dataset, num_users, case=1):
    """Non-IID split of ``dataset`` using 100 shards of 600 images each.

    ``case`` selects the partitioning scheme; see ``non_iid``.
    """
    shards, imgs_per_shard = 100, 600
    return non_iid(dataset, num_users, shards, imgs_per_shard, case=case)


def fashion_mnist_iid(dataset, num_users,dict_global):
    """IID split of ``dataset`` across ``num_users``, excluding ``dict_global`` indices."""
    return iid(dataset, num_users, dict_global=dict_global)


def fashion_mnist_noniid(dataset, num_users, case=1):
    """Non-IID split of ``dataset`` using 100 shards of 600 images each.

    ``case`` selects the partitioning scheme; see ``non_iid``.
    """
    shards, imgs_per_shard = 100, 600
    return non_iid(dataset, num_users, shards, imgs_per_shard, case=case)

def cifar_iid(dataset,num_users,dict_global):
    """IID split of ``dataset`` across ``num_users``, excluding ``dict_global`` indices."""
    return iid(dataset, num_users, dict_global=dict_global)

def iid(dataset, num_users, dict_global=()):
    """Partition the indices of ``dataset`` into disjoint, equally sized random subsets.

    Args:
        dataset: sized container; only ``len(dataset)`` is used.
        num_users: number of subsets (users) to create.
        dict_global: indices held out from every user. It defaults to empty
            so two-argument callers (e.g. ``mnist_iid``) work.

    Returns:
        dict mapping user id ``0 .. num_users - 1`` to a list of Python int
        indices. Each user gets ``int((len(dataset) - len(dict_global)) / num_users)``
        indices; the remainder of that division is left unassigned.
    """
    num_items = int((len(dataset) - len(dict_global)) / num_users)
    print(num_items)
    dict_users = {}
    all_idxs = list(set([i for i in range(len(dataset))]) - set(dict_global))
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        # Remove this user's picks so later users draw from disjoint indices.
        all_idxs = list(set(all_idxs) - dict_users[i])

    for i in range(num_users):
        # Round-trip through numpy so np.int64 entries become plain ints.
        dict_users[i] = np.array(list(dict_users[i])).tolist()
    return dict_users

def cifar_noniid(dataset,num_users,case=1):
    """Non-IID split of ``dataset`` using 100 shards of 500 images each.

    ``case`` selects the partitioning scheme; see ``non_iid``.
    """
    shards, imgs_per_shard = 100, 500
    return non_iid(dataset, num_users, shards, imgs_per_shard, case=case)


def non_iid(dataset, num_users, num_shards, num_imgs, case=1):
    """Dispatch to a non-IID partitioning scheme.

    Args:
        dataset: object exposing ``targets``, passed through to the scheme.
        num_users: number of users to create.
        num_shards, num_imgs: shard count and shard size.
        case: 1 -> one full shard per user (``ratio=1``);
              2 -> two shards per user, using ``2 * num_shards`` shards of
                   ``int(num_imgs / 2)`` images (``noniid_label_2``);
              3 -> one shard per user with ``ratio=0.8``;
              4 -> one shard per user with ``ratio=0.5``.

    Returns:
        The dict of user id -> indices produced by the chosen scheme.

    Raises:
        SystemExit: if ``case`` is not one of 1-4.
    """
    if case == 1:
        return noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs)
    elif case == 2:
        return noniid_label_2(dataset, num_users, int(num_shards * 2), int(num_imgs / 2))
    elif case == 3:
        return noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs, ratio=0.8)
    elif case == 4:
        return noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs, ratio=0.5)
    else:
        # The builtin `exit` is a `site` helper for the interactive shell.
        # It is missing under `python -S` and closes sys.stdin before raising.
        # Raise the same exception type directly instead.
        raise SystemExit('Error: unrecognized noniid case')

def noniid_ratio_r_label_1(dataset, num_users, num_shards, num_imgs, ratio=1):
    """Label-skewed split in which each user is dominated by one label-sorted shard.

    The indices ``0 .. num_shards * num_imgs - 1`` are sorted by
    ``dataset.targets`` and cut into ``num_shards`` contiguous shards of
    ``num_imgs`` indices. Each user draws one distinct random shard and gets
    its leading ``ratio`` fraction. When ``ratio < 1``, the trailing
    ``1 - ratio`` part of *every* shard (drawn or not) is pooled and dealt
    out at random, in equal amounts, to all users.

    Assumes ``len(dataset.targets) == num_shards * num_imgs`` (``np.vstack``
    fails otherwise) and ``num_users <= num_shards`` (drawing from an empty
    shard list raises). Randomness comes from both ``random`` and
    ``np.random``, so seed both for reproducible splits.

    Args:
        dataset: object exposing ``targets`` (one label per sample).
        num_users: number of users to create.
        num_shards: number of label-sorted shards.
        num_imgs: indices per shard.
        ratio: fraction of a user's own shard it receives (1 = whole shard).

    Returns:
        dict mapping user id ``0 .. num_users - 1`` to a list of int indices.
    """
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)
    labels = dataset.targets

    # Sort indices by label so each contiguous shard is label-homogeneous.
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 1, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand * num_imgs:int((rand + ratio) * num_imgs)]),
                                           axis=0)
            random.shuffle(dict_users[i])

    if ratio < 1:
        # Pool the tail of every shard. Use the same slice bound as above so
        # that no index is lost or duplicated.
        rest_idxs = np.array([], dtype='int64')
        for shard in range(num_shards):
            rest_idxs = np.concatenate((rest_idxs, idxs[int((shard + ratio) * num_imgs):(shard + 1) * num_imgs]), axis=0)
        # Derive each user's share from the pool's real size. The former
        # int(len(dataset) / num_users * (1 - ratio)) truncated float error
        # (e.g. int(600 * (1 - 0.8)) == 119), which dropped samples.
        # This form also can never ask for more than the pool holds.
        num_items = len(rest_idxs) // num_users
        for i in range(num_users):
            rest_to_add = set(np.random.choice(rest_idxs, num_items, replace=False))
            dict_users[i] = np.concatenate((dict_users[i], list(rest_to_add)), axis=0)
            rest_idxs = list(set(rest_idxs) - rest_to_add)
            random.shuffle(dict_users[i])

    for i in range(num_users):
        dict_users[i] = dict_users[i].tolist()

    return dict_users


def noniid_label_2(dataset, num_users, num_shards, num_imgs):
    """Two-shard label-skewed split: one low-label and one high-label shard per user.

    The indices ``0 .. num_shards * num_imgs - 1`` are sorted by
    ``dataset.targets`` and cut into ``num_shards`` contiguous shards of
    ``num_imgs`` indices. For every user, the remaining shard ids are kept in
    ascending order. One shard is drawn from the lower half and one from the
    upper half, so with label-sorted shards a user tends to see two
    different labels.

    Assumes ``len(dataset.targets) == num_shards * num_imgs`` and at least two
    shards remain for every user (``np.random.choice`` on an empty half raises).

    Returns:
        dict mapping user id ``0 .. num_users - 1`` to a numpy array of
        indices. Note that, unlike ``noniid_ratio_r_label_1``, this is not
        converted to a list.
    """
    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)
    labels = dataset.targets

    # Sort indices by label so each contiguous shard is label-homogeneous.
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    for i in range(num_users):
        half = len(idx_shard) // 2
        rand1 = np.random.choice(idx_shard[:half], 1, replace=False)[0]
        rand2 = np.random.choice(idx_shard[half:], 1, replace=False)[0]
        # sorted() is required: set iteration order is not ascending in
        # general (CPython orders small ints by hash & table mask). Without
        # it, the lower/upper-half split above would stop meaning
        # low-label vs high-label shards.
        idx_shard = sorted(set(idx_shard) - {rand1, rand2})
        for rand in (rand1, rand2):
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand * num_imgs:int((rand + 1) * num_imgs)]), axis=0)
            random.shuffle(dict_users[i])
    return dict_users
